
Conversation

@sudhakarsingh27
Collaborator

Description

FusedAttention has supported "right"-side sliding window attention for some time now. This PR adds support for SWA (left, right) with the FusedAttention backend in TE.
(changes cherry-picked from original PR: #1369)
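
For context when reviewing, below is a minimal usage sketch of what (left, right) SWA with FusedAttention is meant to enable. It assumes the existing transformer_engine.pytorch.DotProductAttention API plus the new bottom_right_diagonal forward argument added by this PR; shapes, values, and defaults are illustrative only.

import torch
import transformer_engine.pytorch as te

# Bidirectional sliding window: each query may attend to up to 512 tokens to the
# left and 512 tokens to the right of the (aligned) diagonal.
attn = te.DotProductAttention(
    num_attention_heads=16,
    kv_channels=64,
    attn_mask_type="no_mask",
    window_size=(512, 512),  # (left, right) instead of the previously supported (left, 0)
)

# qkv_format defaults to "sbhd": (seq, batch, heads, head_dim)
q = torch.randn(2048, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn(2048, 2, 16, 64, dtype=torch.bfloat16, device="cuda")
v = torch.randn(2048, 2, 16, 64, dtype=torch.bfloat16, device="cuda")

# bottom_right_diagonal (new in this PR) selects whether the window/ALiBi diagonal
# is aligned to the top-left (False) or bottom-right (True) corner of the softmax matrix.
out = attn(q, k, v, bottom_right_diagonal=True)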

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

transformer_engine

  • common

    • fused_attn
      • fused_attn.cpp
        • add bottom_right_diagonal parameter to the API
        • edit the filters to allow sliding-window configs to pick the arbitrary-seqlen fused attention backend
      • fused_attn_f16_arbitrary_seqlen.cu: add bottom_right_diagonal parameter to the API
      • fused_attn_fp8.cu: add bottom_right_diagonal parameter to the FADescriptor_v1 API
      • utils.h: add bottom_right_diagonal parameter to FADescriptor_v1 API
  • pytorch

    • transformer.py
      • plumb bottom_right_diagonal through the call stack: TransformerLayer --> SelfAttention/CrossAttention
    • attention
      • dot_product_attention
        • backends.py:
          • UnfusedDotProductAttention
            • add bottom_right_diagonal parameter to the forward API
              • why is it not used in the forward pass?
                • bottom_right_alignment is used in the ALiBi call; perhaps this should be corrected
          • FusedAttn custom module
            • add bottom_right_diagonal parameter to the forward API
          • FusedAttention module
            • plumb bottom_right_diagonal through the call stack
        • dot_product_attention.py
          • DotProductAttention
            • Plumb bottom_right_diagonal through the call stack
            • Add calculation of bottom_right_diagonal if it's None
        • utils.py
          • AttentionParams
            • add bottom_right_diagonal field
          • get_attention_backend
            • update sliding window filter section
            • update attention bias filter section
      • multi_head_attention.py
        • Add bottom_right_diagonal to forward API and call
        • Add calculation of bottom_right_diagonal if it's None
    • cpp_extensions
      • fused_attn.py
        • plumb bottom_right_diagonal in fused_attn_fwd/fused_attn_bwd
    • csrc
      • extensions
        • attention.cpp
          • plumb bottom_right_diagonal through the call stack: fused_attn_fwd --> nvte_fused_attn_fwd
          • same as above for bwd
      • extensions.h
        • add bottom_right_diagonal to fused_attn_fwd and fused_attn_bwd API definitions

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

@greptile-apps
Contributor

greptile-apps bot commented Dec 4, 2025

Greptile Summary

  • Adds sliding window attention (SWA) support with configurable left/right diagonal alignment to the FusedAttention backend by introducing a bottom_right_diagonal parameter throughout the API stack
  • Extends backend selection logic to support arbitrary-sequence-length fused attention for sliding window configurations, enabling window patterns like (>=0, >=0) instead of just (-1, 0) or (>=0, 0); see the toy mask sketch after this summary
  • Updates test configurations to validate bidirectional sliding window attention with both MHA and GQA attention types using window sizes like (512, 512)
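
To make the (left, right) window and diagonal-alignment semantics concrete, here is a toy reference mask builder in plain PyTorch. It mirrors the window_size docstring formula quoted later in this thread and is for illustration only, not the fused kernel implementation.

import torch

def swa_allowed(seqlen_q, seqlen_kv, window_size, bottom_right_diagonal):
    # Returns a [seqlen_q, seqlen_kv] boolean matrix where True means
    # "query i may attend to key j" under sliding window attention.
    # window_size = (left, right); -1 means unbounded on that side.
    left, right = window_size
    shift = seqlen_kv - seqlen_q if bottom_right_diagonal else 0
    i = torch.arange(seqlen_q).unsqueeze(1)   # query positions
    j = torch.arange(seqlen_kv).unsqueeze(0)  # key positions
    lo = i + shift - (left if left >= 0 else seqlen_kv)
    hi = i + shift + (right if right >= 0 else seqlen_kv)
    return (j >= lo) & (j <= hi)

# Cross-attention example with window (left, right) = (2, 1):
print(swa_allowed(4, 6, (2, 1), bottom_right_diagonal=False).int())  # top-left aligned
print(swa_allowed(4, 6, (2, 1), bottom_right_diagonal=True).int())   # bottom-right aligned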

Important Files Changed

  • transformer_engine/common/fused_attn/fused_attn_fp8.cu: contains hardcoded true values as placeholders for the bottom_right_diagonal parameter instead of proper parameter threading; incomplete implementation
  • transformer_engine/pytorch/transformer.py: parameter override logic has a potential issue where explicit user values may be overridden based on mask type, regardless of user intent
  • transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py: complex parameter resolution logic unconditionally overrides user-provided bottom_right_diagonal values based on mask type

Confidence score: 3/5

  • This PR requires careful review due to the incomplete implementation in the FP8 backend and the potentially confusing parameter-override behavior
  • Score lowered due to hardcoded placeholder values in FP8 implementation and parameter override logic that may violate user expectations when explicit values are provided
  • Pay close attention to fused_attn_fp8.cu, transformer.py, and dot_product_attention.py for proper parameter handling and user experience

Contributor

@greptile-apps greptile-apps bot left a comment


Additional Comments (2)

  1. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1281 (link)

    logic: Trailing comma creates a single-element tuple instead of a boolean; should this be just bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]? (See the illustration after this list.)

  2. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1482 (link)

    style: Uses hardcoded mask type check instead of the new bottom_right_diagonal parameter for ALiBi alignment. Should this use bottom_right_diagonal parameter for consistency?

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
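
A quick standalone illustration of the trailing-comma pitfall flagged in comment 1 (values are hypothetical):

attn_mask_type = "causal"

# With the stray trailing comma, the right-hand side is a one-element tuple:
bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"],
print(bottom_right_alignment)        # (False,)
print(bool(bottom_right_alignment))  # True -- any non-empty tuple is truthy

# Intended boolean:
bottom_right_alignment = attn_mask_type not in ["causal", "padding_causal"]
print(bottom_right_alignment)        # False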

15 files reviewed, 8 comments


Comment on lines +777 to +783
if self_attn_mask_type in {"causal", "padding_causal"}:
    bottom_right_diagonal = False
if bottom_right_diagonal is None or self_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    bottom_right_diagonal = True
Contributor


logic: Logic overrides the instance variable even when explicitly set in forward call - should preserve user's explicit choice. Should the mask type check override an explicitly passed bottom_right_diagonal parameter, or only apply when it's None?

Comment on lines +787 to +793
if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
    enc_dec_bottom_right_diagonal = False
if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    enc_dec_bottom_right_diagonal = True
Contributor


logic: Same logic issue as above - mask type check overrides explicit parameter values

Collaborator

@cyanguwa cyanguwa Dec 5, 2025


Could you add a couple of SWA tests to the CP tests as well? I think it's just a matter of replacing (left, 0) with (left, right) and testing them out. Thanks!

Contributor

@greptile-apps greptile-apps bot left a comment


15 files reviewed, no comments


Contributor

@greptile-apps greptile-apps bot left a comment


Additional Comments (4)

  1. transformer_engine/pytorch/transformer.py, line 777-783 (link)

    logic: Conditional logic overwrites explicitly passed parameter. When self_attn_mask_type in {"causal", "padding_causal"} at line 777, this unconditionally sets bottom_right_diagonal = False, even if the user explicitly passed a different value in the forward call. The check at line 779 (if bottom_right_diagonal is None) becomes unreachable in those cases.

    Consider restructuring to only apply defaults when bottom_right_diagonal is None (see the sketch after this comment list).

  2. transformer_engine/pytorch/transformer.py, line 787-793 (link)

    logic: Same conditional logic issue for encoder-decoder attention. The mask type check at line 787 unconditionally overwrites enc_dec_bottom_right_diagonal, ignoring explicitly passed values.

  3. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 920 (link)

    logic: Variable name mismatch: setting use_flash_attention instead of use_flash_attention_2

  4. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 947 (link)

    logic: Variable name mismatch: setting use_flash_attention instead of use_flash_attention_2
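
A sketch of the restructuring suggested in comment 1 above (illustrative only; variable names are taken from the quoted transformer.py code, and the PR's final behavior is discussed further down in this thread):

# Hypothetical inputs for illustration:
self_attn_mask_type = "causal_bottom_right"
bottom_right_diagonal = None

# Only fill in a default when the caller did not pass an explicit value;
# mask types other than 'causal'/'padding_causal' default to bottom-right alignment.
if bottom_right_diagonal is None:
    if self_attn_mask_type in {"causal", "padding_causal"}:
        bottom_right_diagonal = False
    else:
        bottom_right_diagonal = True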

15 files reviewed, 4 comments


@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

1 similar comment
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

Contributor

@greptile-apps greptile-apps bot left a comment


Additional Comments (1)

  1. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1283 (link)

    syntax: Tuple has trailing comma causing it to be a single-element tuple instead of boolean

15 files reviewed, 1 comment


@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

Contributor

@greptile-apps greptile-apps bot left a comment


15 files reviewed, no comments


Signed-off-by: Sudhakar Singh <[email protected]>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L0

@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

Contributor

@greptile-apps greptile-apps bot left a comment


Additional Comments (6)

  1. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 911 (link)

    logic: Variable name mismatch: setting use_flash_attention instead of use_flash_attention_2

  2. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 938 (link)

    logic: Variable name mismatch: setting use_flash_attention instead of use_flash_attention_2

  3. transformer_engine/pytorch/transformer.py, line 777-783 (link)

    logic: Logic overrides user's explicit parameter. If user passes bottom_right_diagonal=True and mask type is causal, line 778 forces it to False, then line 783 forces it back to True if mask type contains "bottom_right". Should only apply defaults when parameter is None.

  4. transformer_engine/pytorch/transformer.py, line 787-793 (link)

    logic: Same logic issue: overrides user's explicit enc_dec_bottom_right_diagonal parameter

  5. transformer_engine/pytorch/attention/multi_head_attention.py, line 751-757 (link)

    logic: Same logic issue as transformer.py: overrides user's explicit bottom_right_diagonal parameter

  6. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1099-1105 (link)

    logic: Same logic issue: overrides user's explicit bottom_right_diagonal parameter

17 files reviewed, 6 comments


Contributor

@greptile-apps greptile-apps bot left a comment


Additional Comments (6)

  1. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 911 (link)

    syntax: Variable name mismatch - should be use_flash_attention_2 not use_flash_attention

  2. transformer_engine/pytorch/attention/dot_product_attention/utils.py, line 938 (link)

    syntax: Same variable name mismatch - should be use_flash_attention_2 not use_flash_attention

  3. transformer_engine/pytorch/transformer.py, line 777-778 (link)

    logic: Logic unconditionally overrides user-provided bottom_right_diagonal parameter. If user explicitly passes bottom_right_diagonal=True with self_attn_mask_type="causal", it gets overridden to False. Should the mask type check only apply when bottom_right_diagonal is None, not override explicit parameter values?

  4. transformer_engine/pytorch/transformer.py, line 787-788 (link)

    logic: Same issue - logic overrides explicit enc_dec_bottom_right_diagonal parameter values based on mask type. Should this only apply when the parameter is None?

  5. transformer_engine/pytorch/attention/multi_head_attention.py, line 751-752 (link)

    logic: Same logic issue as in transformer.py - mask type check overrides explicit bottom_right_diagonal values. The pattern of checking mask type should only apply when the parameter is None

  6. transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py, line 1098-1099 (link)

    logic: Consistent with other files - mask type check overrides explicit bottom_right_diagonal values instead of only applying when None

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

17 files reviewed, 6 comments


@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

Signed-off-by: Sudhakar Singh <[email protected]>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch

@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

Contributor

@greptile-apps greptile-apps bot left a comment


5 files reviewed, 5 comments


"Disabling FlashAttention as it only supports sliding window with bottom right"
" diagonal alignment for cross-attention"
)
use_flash_attention = False
Contributor


Critical: Variable name mismatch causing incorrect backend selection

This line sets use_flash_attention instead of use_flash_attention_2, which was the intended target based on the conditional check on line 906. This causes FlashAttention to be incorrectly disabled when it should only disable FlashAttention 2.

Suggested change:
- use_flash_attention = False
+ use_flash_attention_2 = False

Impact: This bug causes FlashAttention to be incorrectly disabled for sliding window attention with cross-attention scenarios where bottom_right_diagonal=False and max_seqlen_q != max_seqlen_kv. The code should only disable FlashAttention 2 in this case, not all FlashAttention backends.
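
A paraphrased sketch of why the distinction between the two flags matters in the backend-selection cascade (flag names are from the comment above; the surrounding utils.py logic is simplified, not copied):

# Both flags start True; individual filters are meant to rule out only the
# backends that cannot support the requested configuration.
use_flash_attention = True     # "some FlashAttention backend is still viable"
use_flash_attention_2 = True   # "FlashAttention 2 specifically is still viable"

bottom_right_diagonal = False
max_seqlen_q, max_seqlen_kv = 512, 2048  # cross-attention

if not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv:
    # Only FA2's limitation applies, so only FA2 should be disqualified.
    use_flash_attention_2 = False
    # Writing `use_flash_attention = False` here (the reported bug) would rule out
    # every FlashAttention backend, not just FA2.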

"Disabling FlashAttention as it only supports ALiBi with bottom right diagonal"
" alignment for cross-attention"
)
use_flash_attention = False
Contributor


Critical: Variable name mismatch causing incorrect backend selection

Similar to line 911, this line sets use_flash_attention instead of use_flash_attention_2. The conditional check on line 933 verifies not bottom_right_diagonal and max_seqlen_q != max_seqlen_kv for FlashAttention 2 with ALiBi, but then incorrectly disables all FlashAttention instead of just FlashAttention 2.

Suggested change:
- use_flash_attention = False
+ use_flash_attention_2 = False

Impact: This causes FlashAttention to be incorrectly disabled for ALiBi attention with cross-attention where bottom_right_diagonal=False and max_seqlen_q != max_seqlen_kv. Only FlashAttention 2 should be disabled in this scenario.

Comment on lines +1099 to +1105
if attn_mask_type in {"causal", "padding_causal"}:
    bottom_right_diagonal = False
if bottom_right_diagonal is None or attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    bottom_right_diagonal = True
Contributor


Logic issue: Explicit bottom_right_diagonal parameter is unconditionally overridden

When a user explicitly passes bottom_right_diagonal=False and attn_mask_type is "causal_bottom_right", the code first sets it to False on line 1100, but then immediately overrides it to True on line 1105. This means the explicit parameter value is ignored.

Problem scenario:

# User explicitly requests bottom_right_diagonal=False
forward(..., attn_mask_type="causal_bottom_right", bottom_right_diagonal=False)

# Line 1100: bottom_right_diagonal = False  (respects explicit value)
# Line 1105: bottom_right_diagonal = True   (overrides explicit value!)

The condition on line 1101 checks bottom_right_diagonal is None, which will be False after line 1100, but then the or clause still forces it to True.

Suggested fix: The logic should respect explicit values and only set defaults when bottom_right_diagonal is None:

if attn_mask_type in {"causal", "padding_causal"}:
    if bottom_right_diagonal is None:
        bottom_right_diagonal = False
elif bottom_right_diagonal is None:
    # Only set to True if not explicitly specified
    if attn_mask_type in {"causal_bottom_right", "padding_causal_bottom_right"}:
        bottom_right_diagonal = True
    else:
        bottom_right_diagonal = True  # default for other mask types

This same pattern appears in multiple locations and should be fixed consistently throughout the codebase.

Collaborator


Could we add some explanation in the init or fwd docstring that bottom_right_diagonal will be forced to True if the mask type has bottom_right in it? Just so users (and Greptile) don't get confused. It would also help remove this comment from Greptile.

Contributor


You're absolutely right! The current behavior where bottom_right_diagonal gets overridden based on the mask type should be clearly documented. This would help users understand when their explicit parameter values will be respected versus when they'll be overridden.

Looking at the current logic, I can see that:

  1. Line 1100: If mask type is "causal" or "padding_causal", it's forced to False
  2. Lines 1101-1105: If it's None OR mask type contains "bottom_right", it's forced to True

Here's a suggestion to add clarity to the docstring:

Suggested change:

Current code:

if attn_mask_type in {"causal", "padding_causal"}:
    bottom_right_diagonal = False
if bottom_right_diagonal is None or attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    bottom_right_diagonal = True

Proposed docstring addition:

bottom_right_diagonal: Optional[bool], default = `None`
    Align sliding window and ALiBi diagonal to the top left (`False`)
    or bottom right (`True`) corner of the softmax matrix in the encoder.
    If `None`, it will be set to `False` for `attn_mask_type` =
    {'causal', 'padding_causal'} and `True` for other mask types.
    Note: This parameter will be automatically overridden based on the
    `attn_mask_type` - it will be forced to `False` for 'causal' and
    'padding_causal' mask types, and forced to `True` for mask types
    containing 'bottom_right' (e.g., 'causal_bottom_right',
    'padding_causal_bottom_right'), regardless of the explicitly passed value.

This documentation makes it clear that the parameter can be overridden, helping users understand the actual behavior and avoiding confusion about when explicit values are respected.

Comment on lines +751 to +757
if attn_mask_type in {"causal", "padding_causal"}:
    bottom_right_diagonal = False
if bottom_right_diagonal is None or attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    bottom_right_diagonal = True
Contributor


Logic issue: Explicit bottom_right_diagonal parameter is unconditionally overridden

This has the same logic flaw as in dot_product_attention.py. When a user explicitly passes bottom_right_diagonal=False with attn_mask_type="causal_bottom_right", the explicit value gets overridden.

The code flow:

  1. Line 752: Sets bottom_right_diagonal = False for causal masks
  2. Line 753-757: Immediately overrides to True if mask type contains "bottom_right"

This means explicit parameter values are ignored. The logic should only set defaults when bottom_right_diagonal is None, not override explicit values.

Comment on lines +777 to +793
if self_attn_mask_type in {"causal", "padding_causal"}:
    bottom_right_diagonal = False
if bottom_right_diagonal is None or self_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    bottom_right_diagonal = True

if enc_dec_bottom_right_diagonal is None:
    enc_dec_bottom_right_diagonal = self.enc_dec_bottom_right_diagonal
if enc_dec_attn_mask_type in {"causal", "padding_causal"}:
    enc_dec_bottom_right_diagonal = False
if enc_dec_bottom_right_diagonal is None or enc_dec_attn_mask_type in {
    "causal_bottom_right",
    "padding_causal_bottom_right",
}:
    enc_dec_bottom_right_diagonal = True
Contributor


Logic issue: Explicit bottom_right_diagonal parameters are unconditionally overridden

This code has the same logic flaw as seen in other files. Both bottom_right_diagonal (lines 777-783) and enc_dec_bottom_right_diagonal (lines 787-793) suffer from the same issue where explicit parameter values get overridden.

When a user explicitly passes these parameters with specific mask types, the explicit values are ignored because:

  1. Lines 778, 788: Set to False for causal masks
  2. Lines 779-783, 789-793: Immediately override to True if mask type contains "bottom_right"

The logic should respect explicit values and only set defaults when the parameters are None.

@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

@greptile-apps
Contributor

greptile-apps bot commented Jan 13, 2026

Greptile's behavior is changing!

From now on, if a review finishes with no comments, we will not post an additional "statistics" comment to confirm that our review found nothing to comment on. However, you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

_attention_backends["backend_selection_requires_update"] = True
available_backends, flash_attention_backend, fused_attention_backend = test()
if fused_attention_backend == FusedAttnBackend[backends[i]]:
    fused_attn_backends.append(fused_attention_backend)
Collaborator


Is this change intended, or is it for debugging? Would this fix the printing issue we talked about before, where the unit test wouldn't print the selection logic from this for loop?

Collaborator Author


Apparently its use here is a bit problematic, so removing it helps print the information. (I vaguely remember that you suggested I remove it; I then checked the logging documentation and you were right.)

 with logging_context(highest_level=AttentionLogging._log_level):
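
For readers unfamiliar with the helper that was removed here: logging_context(highest_level=...) is typically the standard logging.disable recipe sketched below (an assumption about the test utility, not its exact implementation in TE), which is why keeping the backend-selection loop inside it suppressed the log output.

import logging
from contextlib import contextmanager

@contextmanager
def logging_context(highest_level=logging.CRITICAL):
    # Disable every record at or below `highest_level` for the duration
    # of the block, then restore the previous disable level.
    previous = logging.root.manager.disable
    logging.disable(highest_level)
    try:
        yield
    finally:
        logging.disable(previous)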

    fu_core_attention_bias_shape = "1hss"
-   elif len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
+   if len(alibi_slopes_shape) == 1 and alibi_slopes_shape[0] == num_heads:
Collaborator


Why is the `alibi_slopes_shape is None` branch removed here?

Collaborator Author

@sudhakarsingh27 sudhakarsingh27 Jan 14, 2026


The encapsulating if condition makes sure that this branch is always false:

if (
    ...
    and (alibi_slopes_shape is not None)
):

Signed-off-by: Sudhakar Singh <[email protected]>
@sudhakarsingh27
Collaborator Author

/te-ci pytorch L1

Contributor

@greptile-apps greptile-apps bot left a comment


17 files reviewed, 3 comments


in [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. Special cases (-1, -1) and (-1, 0) mean no sliding window and causal mask specifically.
bottom_right_diagonal: bool, default = True
Contributor


style: Documentation inconsistency: forward function documents default as None but backward documents True

Suggested change:
- bottom_right_diagonal: bool, default = True
+ bottom_right_diagonal: bool, default = None

0,
0,
true,
true,
Contributor


logic: Critical: Hardcoded parameter instead of proper API parameter

This hardcoded true value should be replaced with a proper bottom_right_diagonal parameter passed from the function signature. The current implementation bypasses the intended parameter passing mechanism.

Issue: The function fused_attn_fp8_fwd_impl_v1 should accept bottom_right_diagonal as a parameter and pass it here, rather than hardcoding true. This breaks the configurability that the PR is supposed to add. Should this hardcoded true be replaced with a bottom_right_diagonal parameter that gets passed through the function signature?

0,
0,
true,
false,
Contributor


logic: Critical: Hardcoded parameter instead of proper API parameter

Similar to the forward pass, this hardcoded false value should be replaced with a proper bottom_right_diagonal parameter passed from the function signature.

Issue: The function fused_attn_fp8_bwd_impl_v1 should accept bottom_right_diagonal as a parameter and pass it here. The hardcoded false suggests different behavior for backward pass, but this should still be configurable. Should this hardcoded false be replaced with a bottom_right_diagonal parameter, and why is it different from the forward pass value?
